{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Taxi Cross-Validation with GPU-Accelerated XGBoost\n",
    "\n",
    "In this notebook, we show how to leverage GPUs to accelerate cross-validation of XGBoost models on the taxi dataset, in order to find the best model for a given grid of parameters.\n",
    "\n",
    "## Import classes\n",
    "First, we need to load some common classes that both the GPU version and the CPU version use:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel, XGBoostRegressor}\n",
    "import org.apache.spark.ml.evaluation.RegressionEvaluator\n",
    "import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}\n",
    "import org.apache.spark.sql.SparkSession\n",
    "import org.apache.spark.sql.types.{FloatType, IntegerType, StructField, StructType}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What is new to xgboost-spark users is `rapids.GpuDataReader` and **rapids.CrossValidator**:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "// import ml.dmlc.xgboost4j.scala.spark.rapids.CrossValidator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set dataset path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dataRoot = /data\n",
       "trainParquetPath = /data/taxi/parquet/train\n",
       "evalParquetPath = /data/taxi/parquet/eval\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "/data/taxi/parquet/eval"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "// Update these paths to match your environment. The input data files can be the output of the taxi-etl jobs,\n",
    "// or you can use the provided sample datasets under the datasets path.\n",
    "val dataRoot = sys.env.getOrElse(\"DATA_ROOT\", \"/data\")\n",
    "val trainParquetPath = dataRoot + \"/taxi/parquet/train\"\n",
    "val evalParquetPath = dataRoot + \"/taxi/parquet/eval\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set the schema of the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "labelColName = fare_amount\n",
       "schema = StructType(StructField(vendor_id,FloatType,true), StructField(passenger_count,FloatType,true), StructField(trip_distance,FloatType,true), StructField(pickup_longitude,FloatType,true), StructField(pickup_latitude,FloatType,true), StructField(rate_code,FloatType,true), StructField(store_and_fwd,FloatType,true), StructField(dropoff_longitude,FloatType,true), StructField(dropoff_latitude,FloatType,true), StructField(fare_amount,FloatType,true), StructField(hour,FloatType,true), StructField(year,IntegerType,true), StructField(month,IntegerType,true), StructField(day,FloatType,true), StructField(day_of_week,FloatType,true), StructField(is_weekend,FloatType,true))\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "StructType(StructField(vendor_id,FloatType,true), StructField(passenger_count,FloatType,true), StructField(trip_distance,FloatType,true), StructField(pickup_longitude,FloatType,true), StructField(pickup_latitude,FloatType,true), StructField(rate_code,FloatType,true), StructField(store_and_fwd,FloatType,true), StructField(dropoff_longitude,FloatType,true), StructField(dropoff_latitude,FloatType,true), StructField(fare_amount,FloatType,true), StructField(hour,FloatType,true), StructField(year,IntegerType,true), StructField(month,IntegerType,true), StructField(day,FloatType,true), StructField(day_of_week,FloatType,true), StructField(is_weekend,FloatType,true))"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val labelColName = \"fare_amount\"\n",
    "val schema =\n",
    "    StructType(Array(\n",
    "      StructField(\"vendor_id\", FloatType),\n",
    "      StructField(\"passenger_count\", FloatType),\n",
    "      StructField(\"trip_distance\", FloatType),\n",
    "      StructField(\"pickup_longitude\", FloatType),\n",
    "      StructField(\"pickup_latitude\", FloatType),\n",
    "      StructField(\"rate_code\", FloatType),\n",
    "      StructField(\"store_and_fwd\", FloatType),\n",
    "      StructField(\"dropoff_longitude\", FloatType),\n",
    "      StructField(\"dropoff_latitude\", FloatType),\n",
    "      StructField(labelColName, FloatType),\n",
    "      StructField(\"hour\", FloatType),\n",
    "      StructField(\"year\", IntegerType),\n",
    "      StructField(\"month\", IntegerType),\n",
    "      StructField(\"day\", FloatType),\n",
    "      StructField(\"day_of_week\", FloatType),\n",
    "      StructField(\"is_weekend\", FloatType)\n",
    "    ))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a new Spark session and load data\n",
    "We must create a new Spark session before running any Spark operations. It will also be used to initialize the `GpuDataReader`, which is a data reader powered by the GPU.\n",
    "\n",
    "NOTE: In this notebook, the dependency jars were uploaded when installing the Toree kernel. If they are not uploaded at installation time, they can also be added from within the notebook with the [%AddJar magic](https://toree.incubator.apache.org/docs/current/user/faq/). However, `%AddJar` has one restriction: an uploaded jar is only available if `%AddJar` is called after a new Spark session is created. It must be used as shown below:\n",
    "\n",
    "```scala\n",
    "import org.apache.spark.sql.SparkSession\n",
    "val spark = SparkSession.builder().appName(\"Taxi-GPU-CV\").getOrCreate\n",
    "%AddJar file:/data/libs/rapids-4-spark-XXX.jar\n",
    "%AddJar file:/data/libs/xgboost4j-spark-gpu_2.12-XXX.jar\n",
    "%AddJar file:/data/libs/xgboost4j-gpu_2.12-XXX.jar\n",
    "// ...\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "spark = org.apache.spark.sql.SparkSession@1b953a9c\n",
       "trainDs = [vendor_id: int, passenger_count: int ... 15 more fields]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[vendor_id: int, passenger_count: int ... 15 more fields]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val spark = SparkSession.builder().appName(\"taxi-gpu-cv\").getOrCreate()\n",
    "val trainDs = spark.read.parquet(trainParquetPath)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Determine the features to train on"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "featureNames = Array(vendor_id, passenger_count, trip_distance, pickup_longitude, pickup_latitude, rate_code, store_and_fwd, dropoff_longitude, dropoff_latitude, hour, year, month, day, day_of_week, is_weekend)\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Array(vendor_id, passenger_count, trip_distance, pickup_longitude, pickup_latitude, rate_code, store_and_fwd, dropoff_longitude, dropoff_latitude, hour, year, month, day, day_of_week, is_weekend)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val featureNames = schema.filter(_.name != labelColName).map(_.name).toArray"
   ]
  },
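  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `featureNames` is derived from the hand-written schema rather than from the loaded data, it may be worth checking that every feature column actually exists in the DataFrame before training. A minimal sketch, where `missingColumns` is a hypothetical helper and the sample column names are made up for illustration:\n",
    "\n",
    "```scala\n",
    "// Hypothetical helper: names in `expected` that are absent from `actual`\n",
    "def missingColumns(expected: Seq[String], actual: Seq[String]): Seq[String] = {\n",
    "  val present = actual.toSet\n",
    "  expected.filterNot(present)\n",
    "}\n",
    "\n",
    "// Illustrative input only; in this notebook the call would be\n",
    "// missingColumns(featureNames, trainDs.columns)\n",
    "val missing = missingColumns(Seq(\"hour\", \"day\", \"tip\"), Seq(\"hour\", \"day\", \"fare_amount\"))\n",
    "assert(missing == Seq(\"tip\"))\n",
    "```\n",
    "\n",
    "An empty result means that all expected feature columns are present."
   ]
  },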
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "regressorParam = Map(num_round -> 100, tree_method -> gpu_hist, num_workers -> 1)\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Map(num_round -> 100, tree_method -> gpu_hist, num_workers -> 1)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val regressorParam = Map(\n",
    "    \"num_round\" -> 100,\n",
    "    \"tree_method\" -> \"gpu_hist\",\n",
    "    \"num_workers\" -> 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Construct CrossValidator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "regressor = xgbr_1c1bd6fa3a5f\n",
       "paramGrid = \n",
       "evaluator = RegressionEvaluator: uid=regEval_c7293a967512, metricName=rmse, throughOrigin=false\n",
       "cv = cv_06528fc9d704\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Array({\n",
       "\txgbr_1c1bd6fa3a5f-eta: 0.2,\n",
       "\txgbr_1c1bd6fa3a5f-maxDepth: 3\n",
       "}, {\n",
       "\txgbr_1c1bd6fa3a5f-eta: 0.6,\n",
       "\txgbr_1c1bd6fa3a5f-maxDepth: 3\n",
       "}, {\n",
       "\txgbr_1c1bd6fa3a5f-eta: 0.2,\n",
       "\txgbr_1c1bd6fa3a5f-maxDepth: 10\n",
       "}, {\n",
       "\txgbr_1c1bd6fa3a5f-eta: 0.6,\n",
       "\txgbr_1c1bd6fa3a5f-maxDepth: 10\n",
       "})\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "cv_06528fc9d704"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val regressor = new XGBoostRegressor(regressorParam)\n",
    "    .setLabelCol(labelColName)\n",
    "    .setFeaturesCol(featureNames)\n",
    "val paramGrid = new ParamGridBuilder()\n",
    "    .addGrid(regressor.maxDepth, Array(3, 10))\n",
    "    .addGrid(regressor.eta, Array(0.2, 0.6))\n",
    "    .build()\n",
    "val evaluator = new RegressionEvaluator().setLabelCol(labelColName)\n",
    "val cv = new CrossValidator()\n",
    "    .setEstimator(regressor)\n",
    "    .setEvaluator(evaluator)\n",
    "    .setEstimatorParamMaps(paramGrid)\n",
    "    .setNumFolds(3)"
   ]
  },
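  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The grid above has 2 values of `maxDepth` times 2 values of `eta`, that is, 4 parameter combinations, and each combination is trained once per fold. Spark's `CrossValidator` then refits the best combination once on the full training set. The resulting number of training runs can be sketched as follows; `totalFits` is a hypothetical helper, not part of any library:\n",
    "\n",
    "```scala\n",
    "// Hypothetical helper: number of models CrossValidator trains for a grid search.\n",
    "// gridSizes holds the number of candidate values per tuned parameter.\n",
    "def totalFits(gridSizes: Seq[Int], numFolds: Int): Int = {\n",
    "  val combinations = gridSizes.product // every value paired with every other value\n",
    "  combinations * numFolds + 1          // one fit per fold, plus the final refit of the best combination\n",
    "}\n",
    "\n",
    "// maxDepth has 2 candidates, eta has 2 candidates, and numFolds is 3\n",
    "println(totalFits(Seq(2, 2), 3))\n",
    "```\n",
    "\n",
    "For this notebook that gives 4 * 3 + 1 = 13 runs, which is consistent with the 13 `Tracker started` lines printed by the training cell."
   ]
  },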
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train with CrossValidator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=36551, DMLC_NUM_WORKER=1}\n",
      "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=40153, DMLC_NUM_WORKER=1}\n",
      "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=46553, DMLC_NUM_WORKER=1}\n",
      "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=50795, DMLC_NUM_WORKER=1}\n",
      "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=44927, DMLC_NUM_WORKER=1}\n",
      "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=55309, DMLC_NUM_WORKER=1}\n",
      "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=55163, DMLC_NUM_WORKER=1}\n",
      "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=54783, DMLC_NUM_WORKER=1}\n",
      "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=49873, DMLC_NUM_WORKER=1}\n",
      "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=36003, DMLC_NUM_WORKER=1}\n",
      "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=41429, DMLC_NUM_WORKER=1}\n",
      "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=60783, DMLC_NUM_WORKER=1}\n",
      "Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.19.183.210, DMLC_TRACKER_PORT=49361, DMLC_NUM_WORKER=1}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "model = xgbr_1c1bd6fa3a5f\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "xgbr_1c1bd6fa3a5f"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val model = cv.fit(trainDs).bestModel.asInstanceOf[XGBoostRegressionModel]"
   ]
  },
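  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`cv.fit` returns a `CrossValidatorModel`, and the cell above keeps only its `bestModel`. To see how each parameter combination scored, the fitted model's `getEstimatorParamMaps` and `avgMetrics` can be paired up. A minimal sketch, where `printCvSummary` is a hypothetical helper and the metric is assumed to be RMSE (the `RegressionEvaluator` default), so lower is better:\n",
    "\n",
    "```scala\n",
    "import org.apache.spark.ml.tuning.CrossValidatorModel\n",
    "\n",
    "// Hypothetical helper: list every parameter combination with its\n",
    "// cross-validated average metric, lowest (best for RMSE) first.\n",
    "def printCvSummary(cvModel: CrossValidatorModel): Unit =\n",
    "  cvModel.getEstimatorParamMaps\n",
    "    .zip(cvModel.avgMetrics)\n",
    "    .sortBy(_._2)\n",
    "    .foreach { case (params, metric) =>\n",
    "      println(s\"avg metric = $metric for $params\")\n",
    "    }\n",
    "```\n",
    "\n",
    "To use it, keep the result of the fit in a value first, for example `val cvModel = cv.fit(trainDs)`, and call `printCvSummary(cvModel)` before extracting `bestModel`."
   ]
  },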
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transform with the best model trained by CrossValidator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "transformDs = [vendor_id: int, passenger_count: int ... 15 more fields]\n",
       "df = [vendor_id: int, passenger_count: int ... 16 more fields]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------+------------------+\n",
      "|fare_amount|        prediction|\n",
      "+-----------+------------------+\n",
      "|       11.4|12.278875350952148|\n",
      "|        7.4|7.4439215660095215|\n",
      "|        5.0| 4.565710067749023|\n",
      "|        8.5| 9.188780784606934|\n",
      "|        7.4| 7.266360759735107|\n",
      "+-----------+------------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[vendor_id: int, passenger_count: int ... 16 more fields]"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val transformDs = spark.read.parquet(evalParquetPath)\n",
    "val df = model.transform(transformDs).cache()\n",
    "df.select(\"fare_amount\", \"prediction\").show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "evaluator = RegressionEvaluator: uid=regEval_1c57378a8fe1, metricName=rmse, throughOrigin=false\n",
       "rmse = 2.2492672858545992\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "2.2492672858545992"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val evaluator = new RegressionEvaluator().setLabelCol(labelColName)\n",
    "val rmse = evaluator.evaluate(df)"
   ]
  },
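  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`RegressionEvaluator` uses RMSE by default, as the `metricName=rmse` in the output above shows. RMSE is the square root of the mean squared difference between label and prediction. A plain-Scala sketch of that formula follows, where `rmseOf` is a hypothetical helper and the sample pairs are the five rows displayed earlier, so the result will not match the value computed over the full evaluation set:\n",
    "\n",
    "```scala\n",
    "// Hypothetical helper: root mean squared error over (label, prediction) pairs\n",
    "def rmseOf(pairs: Seq[(Double, Double)]): Double = {\n",
    "  require(pairs.nonEmpty, \"need at least one pair\")\n",
    "  val squaredErrors = pairs.map { case (label, prediction) => math.pow(label - prediction, 2) }\n",
    "  math.sqrt(squaredErrors.sum / pairs.size)\n",
    "}\n",
    "\n",
    "// The five (fare_amount, prediction) rows shown by df.show(5)\n",
    "val sample = Seq(\n",
    "  (11.4, 12.278875350952148),\n",
    "  (7.4, 7.4439215660095215),\n",
    "  (5.0, 4.565710067749023),\n",
    "  (8.5, 9.188780784606934),\n",
    "  (7.4, 7.266360759735107)\n",
    ")\n",
    "println(rmseOf(sample))\n",
    "```\n",
    "\n",
    "The evaluator applies the same computation to the label and `prediction` columns of the whole DataFrame."
   ]
  },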
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "spark.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "XGBoost4j-Spark - Scala",
   "language": "scala",
   "name": "XGBoost4j-Spark_scala"
  },
  "language_info": {
   "codemirror_mode": "text/x-scala",
   "file_extension": ".scala",
   "mimetype": "text/x-scala",
   "name": "scala",
   "pygments_lexer": "scala",
   "version": "2.12.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}